
examples/dreambooth: fix missing weighting chunk when using prior preservation in Flux and SD3 LoRA training #13743

Open
Dev-X25874 wants to merge 2 commits into huggingface:main from Dev-X25874:fix/dreambooth-prior-preservation-weighting-chunk

Conversation

@Dev-X25874
Contributor

What does this PR do?

When --with_prior_preservation is enabled, the training batch concatenates
instance and class (prior) samples, so every per-sample tensor —
model_pred, target, sigmas, and therefore weighting — has shape
(2 * train_batch_size, ...).
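
As a quick illustration of that batch layout (the latent shape here is illustrative, not the scripts' actual dimensions):

import torch

B = 2  # train_batch_size (illustrative)
instance_latents = torch.randn(B, 16, 64, 64)   # instance samples
class_latents = torch.randn(B, 16, 64, 64)      # class (prior) samples

# The collated batch stacks instance samples first and class samples second,
# so every per-sample tensor derived from it inherits the 2*B leading dim.
latents = torch.cat([instance_latents, class_latents], dim=0)
print(latents.shape[0])  # 2 * B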

Inside the loss block, model_pred and target are correctly split via
torch.chunk(..., 2, dim=0), but weighting was never chunked. This means:

  • weighting (size 2B) is broadcast against model_pred_prior and
    target_prior (size B). With train_batch_size == 1 this silently
    produces a loss tensor of the wrong shape and applies incorrectly
    paired timestep weights to the prior loss term; with a larger batch
    size the broadcast fails outright with a runtime shape error (see
    the shape sketch after this list).
  • The instance loss term also gets weights from the full unsplit weighting
    instead of only the instance-sample half.
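
To make the mismatch concrete, here is a minimal, self-contained sketch of the bad broadcast, using the SD3-style (B, C, H, W) layout described above (the values are illustrative):

import torch

B = 1  # per-device train_batch_size; with B == 1 the bad broadcast succeeds silently
C, H, W = 4, 8, 8

weighting = torch.rand(2 * B, 1, 1, 1)       # never chunked: still size 2B on dim 0
model_pred_prior = torch.randn(B, C, H, W)   # prior half after torch.chunk
target_prior = torch.randn(B, C, H, W)

bad = weighting * (model_pred_prior - target_prior) ** 2
print(bad.shape)  # torch.Size([2, 4, 8, 8]): 2B on dim 0 instead of B

# With B > 1 the same expression raises a RuntimeError instead of silently
# mis-weighting, since dim 0 sizes 2B and B cannot broadcast.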

The correct pattern already exists in train_dreambooth_lora_flux2.py:

weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)

This PR applies the same fix to train_dreambooth_lora_flux.py and
train_dreambooth_lora_sd3.py, which were both missing it.
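
For reference, a sketch of the corrected loss block under prior preservation, mirroring the pattern in train_dreambooth_lora_flux2.py (the exact surrounding code in the two fixed scripts may differ slightly):

if args.with_prior_preservation:
    # The batch stacks instance samples first, class samples second,
    # so chunking along dim 0 separates the two halves.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    # The fix: split weighting the same way so each loss term only sees
    # the timestep weights of its own half of the batch.
    weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)

    # Prior-preservation loss, weighted with the prior half only.
    prior_loss = torch.mean(
        (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
            target_prior.shape[0], -1
        ),
        1,
    ).mean()

# Instance loss, weighted with the instance half only.
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
).mean()

if args.with_prior_preservation:
    loss = loss + args.prior_loss_weight * prior_loss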

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

…target when using prior preservation (flux LoRA)
…target when using prior preservation (SD3 LoRA)
@github-actions bot added the examples and size/S (PR with diff < 50 LOC) labels on May 14, 2026
@Dev-X25874
Contributor Author

Hi @sayakpaul, would you mind taking a look at this when you get a chance?

The bug is present in both train_dreambooth_lora_flux.py and train_dreambooth_lora_sd3.py — when --with_prior_preservation is enabled, weighting is never chunked alongside model_pred and target, causing incorrect timestep weights to be applied to the prior loss term. The fix already exists in train_dreambooth_lora_flux2.py (line 1832), so this PR simply backports it to the two older scripts. Happy to make any changes if needed!

